{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "import gym\n",
    "import numpy as np\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cartpole_model(observation_space, action_space):\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(observation_space, 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(128, action_space),\n",
    "        nn.Softmax(dim=1)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weight(module):\n",
    "    if((type(module) == nn.Linear)):\n",
    "            nn.init.xavier_uniform_(module.weight.data)\n",
    "            module.bias.data.fill_(0.00)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_agents(num_agents, observation_space, action_space):\n",
    "    agents = []\n",
    "    \n",
    "    for _ in range(num_agents):\n",
    "        agent = cartpole_model(observation_space, action_space)\n",
    "        agent.apply(init_weight)\n",
    "        \n",
    "        for param in agent.parameters():\n",
    "            param.requires_grad = False\n",
    "        \n",
    "        agent.eval()\n",
    "        agents.append(agent)\n",
    "        \n",
    "    return agents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_agent(agent, env):\n",
    "    observation = env.reset()\n",
    "    \n",
    "    total_reward = 0\n",
    "    for _ in range(MAX_STEP):\n",
    "        observation = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)\n",
    "        action_probablity = agent(observation).detach().numpy()[0]\n",
    "        action = np.random.choice(range(env.action_space.n), 1, p=action_probablity).item()\n",
    "        next_observation, reward, terminal, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        observation = next_observation\n",
    "        \n",
    "        if terminal:\n",
    "            break\n",
    "\n",
    "    return total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def agent_score(agent, env, runs):\n",
    "    score = 0\n",
    "    for _ in range(runs):\n",
    "        score += eval_agent(agent, env)\n",
    "        \n",
    "    return score/runs "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def all_agent_score(agents, env, runs):\n",
    "    agents_score = []\n",
    "    for agent in agents:\n",
    "        agents_score.append(agent_score(agent, env, runs))\n",
    "    \n",
    "    return agents_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mutate(agent):\n",
    "    child_agent = copy.deepcopy(agent)\n",
    "    \n",
    "    for param in agent.parameters():\n",
    "        mutation_noise = torch.randn_like(param) * MUTATION_POWER\n",
    "        param += mutation_noise\n",
    "        \n",
    "    return child_agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def elite(agents, top_parents_id, env, elite_id=None, top=10):\n",
    "    selected_elites = top_parents_id[:top]\n",
    "    \n",
    "    if elite_id:\n",
    "        selected_elites.append(elite_id)\n",
    "        \n",
    "    top_score = np.NINF\n",
    "    top_id = None\n",
    "    \n",
    "    for agent_id in selected_elites:\n",
    "        \n",
    "        score = agent_score(agents[agent_id], env, runs=5)\n",
    "        if score > top_score:\n",
    "            top_score = score\n",
    "            top_id = agent_id\n",
    "    \n",
    "    return copy.deepcopy(agents[top_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def child_agents(agents, top_parents_id, env, elite_id=None):\n",
    "    children = []\n",
    "    \n",
    "    agent_count = len(agents)-1\n",
    "    \n",
    "    selected_agents_id = np.random.choice(top_parents_id, agent_count)\n",
    "    selected_agents = [agents[id] for id in selected_agents_id]\n",
    "    child_agents = [mutate(agent) for agent in selected_agents]\n",
    "    \n",
    "    child_agents.append(elite(agents, top_parents_id, env))\n",
    "    elite_id = len(child_agents)-1\n",
    "    \n",
    "    return child_agents, elite_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_parents(scores, num_top_parents):\n",
    "    return np.argsort(rewards)[::-1][:num_top_parents]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "ENV_NAME = \"CartPole-v1\"\n",
    "MAX_STEP = 500\n",
    "MUTATION_POWER = 0.02\n",
    "\n",
    "num_agents = 500\n",
    "num_top_parents = 20\n",
    "generations = 25\n",
    "elite_agent = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_grad_enabled(False)\n",
    "env = gym.make(ENV_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "agents = create_agents(num_agents, env.observation_space.shape[0], env.action_space.n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Generation |     Score      |\n",
      "|    001     |    47.0667     |\n",
      "|    002     |    47.3333     |\n",
      "|    003     |    55.7333     |\n",
      "|    004     |    58.2667     |\n",
      "|    005     |    65.3333     |\n",
      "|    006     |    88.0000     |\n",
      "|    007     |    105.5333     |\n",
      "|    008     |    117.4000     |\n",
      "|    009     |    109.4000     |\n",
      "|    010     |    137.6667     |\n",
      "|    011     |    150.3333     |\n",
      "|    012     |    168.6000     |\n",
      "|    013     |    176.2667     |\n",
      "|    014     |    248.0667     |\n",
      "|    015     |    281.6667     |\n",
      "|    016     |    327.9333     |\n",
      "|    017     |    363.5333     |\n",
      "|    018     |    375.4000     |\n",
      "|    019     |    387.0000     |\n",
      "|    020     |    432.2000     |\n",
      "|    021     |    454.6000     |\n",
      "|    022     |    445.9333     |\n",
      "|    023     |    463.7333     |\n",
      "|    024     |    482.1333     |\n",
      "|    025     |    496.2000     |\n"
     ]
    }
   ],
   "source": [
    "print(f'| Generation |     Score      |')\n",
    "for gen in range(generations):\n",
    "    rewards = all_agent_score(agents, env, 3)\n",
    "    top_parents_id = top_parents(rewards, num_top_parents)\n",
    "    agents, elite_agent = child_agents(agents, top_parents_id, env, elite_agent)\n",
    "    print(f'|    {gen+1:03}     |    {np.mean([rewards[i] for i in top_parents_id[:5]]):.4f}     |')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_agent(agent, env):\n",
    "    \n",
    "    observation = env.reset()\n",
    "    total_reward=0\n",
    "    \n",
    "    for _ in range(MAX_STEP):\n",
    "        env.render()\n",
    "        observation = torch.tensor(observation).type('torch.FloatTensor').view(1,-1)\n",
    "        output_probabilities = agent(observation).detach().numpy()[0]\n",
    "        action = np.random.choice(range(2), 1, p=output_probabilities).item()\n",
    "        new_observation, reward, done, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "        observation = new_observation\n",
    "\n",
    "        if(done):\n",
    "            break\n",
    "\n",
    "    env.close()\n",
    "    print(\"Rewards: \",total_reward)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rewards:  350.0\n"
     ]
    }
   ],
   "source": [
    "play_agent(agents[num_agents-1],env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
